/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <dispenso/pool_allocator.h>

#include <algorithm>
#include <cstdlib>
#include <deque>
#include <map>
#include <thread>
#include <vector>

#include <gtest/gtest.h>

// Smoke test: a single chunk can be obtained, written to, and returned to the pool.
TEST(PoolAllocator, SimpleMallocFree) {
  dispenso::PoolAllocator allocator(64, 256, ::malloc, ::free);

  char* buf = allocator.alloc();
  // Fail the test cleanly rather than crashing on the write below if allocation failed.
  ASSERT_NE(nullptr, buf);

  *buf = 'a';

  allocator.dealloc(buf);
}

// Verifies, through instrumented backing alloc/dealloc functions, when the pool requests backing
// buffers, that every chunk it returns lies inside one of them, and that all backing buffers are
// released exactly once when the allocator is destroyed.
TEST(PoolAllocator, TrackAllocations) {
  // Backing buffer start -> length, for every buffer currently obtained through allocFunc.
  std::map<char*, size_t> allocMap;

  auto allocFunc = [&allocMap](size_t len) -> void* {
    char* ret = static_cast<char*>(::malloc(len));
    allocMap.emplace(ret, len);
    return ret;
  };

  auto deallocFunc = [&allocMap](void* ptr) {
    // Each buffer handed back must be one we produced, and must be released only once.
    EXPECT_EQ(1u, allocMap.erase(static_cast<char*>(ptr)));
    ::free(ptr);
  };

  // Check to make sure that the ptr returned by the allocator below is in one of the buffers
  // generated by allocFunc.  upper_bound yields the first buffer starting strictly after ptr, so
  // the buffer immediately before it is the only one that could contain ptr; verify that it does.
  auto checkInValidRange = [&allocMap](char* ptr) {
    auto it = allocMap.upper_bound(ptr);
    if (it == allocMap.begin()) {
      // ptr precedes every tracked buffer (or none exist); stepping back from begin() would be UB.
      ADD_FAILURE() << "pointer does not lie in any buffer obtained from allocFunc";
      return ptr;
    }
    --it;
    EXPECT_GE(ptr, it->first);
    EXPECT_LT(ptr, it->first + it->second);
    return ptr;
  };

  {
    // The expectations below assume four 64-byte chunks fit in one 256-byte backing buffer, so
    // the fifth simultaneously-live chunk forces a second backing allocation.
    dispenso::PoolAllocator allocator(64, 256, allocFunc, deallocFunc);

    char* bufs[5];

    bufs[0] = checkInValidRange(allocator.alloc());

    EXPECT_EQ(1u, allocMap.size());

    bufs[1] = checkInValidRange(allocator.alloc());

    EXPECT_EQ(1u, allocMap.size());

    allocator.dealloc(bufs[0]);

    EXPECT_EQ(1u, allocMap.size());

    bufs[0] = checkInValidRange(allocator.alloc());

    EXPECT_EQ(1u, allocMap.size());

    bufs[2] = checkInValidRange(allocator.alloc());

    EXPECT_EQ(1u, allocMap.size());

    bufs[3] = checkInValidRange(allocator.alloc());

    EXPECT_EQ(1u, allocMap.size());

    bufs[4] = checkInValidRange(allocator.alloc());

    EXPECT_EQ(2u, allocMap.size());

    // Returning a chunk must neither release nor acquire backing memory.
    allocator.dealloc(bufs[4]);
    EXPECT_EQ(2u, allocMap.size());
  }

  // Destroying the allocator must hand every backing buffer back to deallocFunc.
  EXPECT_EQ(0u, allocMap.size());
}

// Several threads repeatedly allocate, stamp, verify, and free chunks from one shared allocator.
// Each thread writes its own id into its chunks and checks the value reads back unchanged, which
// should catch a chunk being handed to two threads at the same time.
TEST(PoolAllocator, SimpleThreaded) {
  constexpr size_t kNumThreads = 8;

  dispenso::PoolAllocator allocator(64, 256, ::malloc, ::free);

  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);

  for (size_t tid = 0; tid < kNumThreads; ++tid) {
    // Compare as char on both sides: avoids a char/size_t signed-unsigned comparison in EXPECT_EQ.
    const char stamp = static_cast<char>(tid);
    threads.emplace_back([&allocator, stamp]() {
      constexpr size_t kNumBufs = 8;
      constexpr size_t kNumIters = 1000;
      char* bufs[kNumBufs];

      for (size_t iter = 0; iter < kNumIters; ++iter) {
        for (char*& buf : bufs) {
          buf = allocator.alloc();
          *buf = stamp;
        }
        for (char* buf : bufs) {
          EXPECT_EQ(stamp, *buf);
          allocator.dealloc(buf);
        }
      }
    });
  }

  for (auto& t : threads) {
    t.join();
  }
}

// Exercises arena-style use: allocate many chunks, then clear() the allocator and allocate again,
// checking each round that every chunk is fully writable and retains the written bytes.
TEST(PoolAllocator, Arena) {
  constexpr size_t kChunkSize = 64;
  dispenso::PoolAllocator allocator(kChunkSize, 256, ::malloc, ::free);

  std::vector<char*> vec;

  // Allocates `count` chunks, fills every byte of each with `value`, then checks that every byte
  // of every chunk still reads back as `value`.
  auto fillAndVerify = [&](size_t count, char value) {
    vec.resize(count);
    for (char*& c : vec) {
      c = allocator.alloc();
      std::fill_n(c, kChunkSize, value);
    }
    for (char* c : vec) {
      EXPECT_TRUE(std::all_of(c, c + kChunkSize, [value](char v) { return v == value; }));
    }
  };

  fillAndVerify(2000, 0x7f);

  // Pointers held in vec are stale after clear(); fillAndVerify overwrites them before any use.
  allocator.clear();
  fillAndVerify(128, 0x22);

  allocator.clear();
  fillAndVerify(48, 0x11);
}
